import random
import time
from pathlib import Path

import numpy as np
import torch

# For reproducibility
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
from diffusers import schedulers
from diffusers.models import AutoencoderKL
from loguru import logger
from PIL import Image
from transformers import BertModel, BertTokenizer
from transformers.modeling_utils import logger as tf_logger

from .constants import (
    SAMPLER_FACTORY,
    NEGATIVE_PROMPT,
    TRT_MAX_WIDTH,
    TRT_MAX_HEIGHT,
    TRT_MAX_BATCH_SIZE,
)
from .diffusion.pipeline import StableDiffusionPipeline
from .diffusion.pipeline_controlnet import StableDiffusionControlNetPipeline
from .lora import load_lora
from .modules.controlnet import HunYuanControlNet
from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
from .modules.text_encoder import MT5Embedder
from .utils.tools import set_seeds


class Resolution:
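    """A (width, height) pair whose string form, "{height}x{width}", is used as the RoPE cache key."""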
    def __init__(self, width, height):
        self.width = width
        self.height = height

    def __str__(self):
        return f"{self.height}x{self.width}"


class ResolutionGroup:
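    """The standard resolutions for which RoPE tables are precomputed (also the
    sizes supported by the TensorRT engines)."""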
    def __init__(self):
        self.data = [
            Resolution(1024, 1024),  # 1:1
            Resolution(1280, 1280),  # 1:1
            Resolution(1024, 768),  # 4:3
            Resolution(1152, 864),  # 4:3
            Resolution(1280, 960),  # 4:3
            Resolution(768, 1024),  # 3:4
            Resolution(864, 1152),  # 3:4
            Resolution(960, 1280),  # 3:4
            Resolution(1280, 768),  # 16:9 (approx.)
            Resolution(768, 1280),  # 9:16 (approx.)
        ]
        self.supported_sizes = {(r.width, r.height) for r in self.data}

    def is_valid(self, width, height):
        return (width, height) in self.supported_sizes


STANDARD_RATIO = np.array(
    [
        1.0,  # 1:1
        4.0 / 3.0,  # 4:3
        3.0 / 4.0,  # 3:4
        16.0 / 9.0,  # 16:9
        9.0 / 16.0,  # 9:16
    ]
)
STANDARD_SHAPE = [
    [(1024, 1024), (1280, 1280)],  # 1:1
    [(1280, 960)],  # 4:3
    [(960, 1280)],  # 3:4
    [(1280, 768)],  # 16:9 (approx.)
    [(768, 1280)],  # 9:16 (approx.)
]
STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]


def get_standard_shape(target_W, target_H):
    """
    Map image size to standard size.
    """
    target_ratio = target_W / target_H
    closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
    closest_area_idx = np.argmin(
        np.abs(STANDARD_AREA[closest_ratio_idx] - target_W * target_H)
    )
    width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
    return width, height


def _to_tuple(val):
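    """Normalize a scalar or a 1- or 2-element sequence to a 2-tuple,
    e.g. 1024 -> (1024, 1024) and [768, 1280] -> (768, 1280)."""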
    if isinstance(val, (list, tuple)):
        if len(val) == 1:
            val = (val[0], val[0])
        elif len(val) == 2:
            val = tuple(val)
        else:
            raise ValueError(f"Invalid value: {val}")
    elif isinstance(val, (int, float)):
        val = (val, val)
    else:
        raise ValueError(f"Invalid value: {val}")
    return val


def get_pipeline(
    args,
    vae,
    text_encoder,
    tokenizer,
    model,
    device,
    rank,
    embedder_t5,
    infer_mode,
    sampler=None,
    controlnet=None,
):
    """
    Get scheduler and pipeline for sampling. The sampler and pipeline are both
    based on diffusers and make some modifications.

    Returns
    -------
    pipeline: StableDiffusionPipeline
    sampler_name: str
    """
    sampler = sampler or args.sampler

    # Load the sampler configuration from the factory. Copy the kwargs so we
    # do not mutate the module-level SAMPLER_FACTORY constant.
    kwargs = SAMPLER_FACTORY[sampler]["kwargs"].copy()
    scheduler_name = SAMPLER_FACTORY[sampler]["scheduler"]

    # Update the sampler configuration according to the arguments
    kwargs["beta_schedule"] = args.noise_schedule
    kwargs["beta_start"] = args.beta_start
    kwargs["beta_end"] = args.beta_end
    kwargs["prediction_type"] = args.predict_type

    # Build the scheduler for the chosen sampler.
    scheduler_class = getattr(schedulers, scheduler_name)
    scheduler = scheduler_class(**kwargs)
    logger.debug(f"Using sampler: {sampler} with scheduler: {scheduler_name}")

    # Set the timesteps for the requested number of inference steps.
    scheduler.set_timesteps(args.infer_steps, device)

    # Only enable progress bar for rank 0
    progress_bar_config = {} if rank == 0 else {"disable": True}
    if not controlnet:
        pipeline = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=model,
            scheduler=scheduler,
            feature_extractor=None,
            safety_checker=None,
            requires_safety_checker=False,
            progress_bar_config=progress_bar_config,
            embedder_t5=embedder_t5,
            infer_mode=infer_mode,
        )
    else:
        pipeline = StableDiffusionControlNetPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=model,
            scheduler=scheduler,
            feature_extractor=None,
            safety_checker=None,
            requires_safety_checker=False,
            progress_bar_config=progress_bar_config,
            embedder_t5=embedder_t5,
            infer_mode=infer_mode,
            controlnet=controlnet,
        )
    pipeline = pipeline.to(device)

    return pipeline, sampler


class End2End:
    def __init__(self, args, models_root_path):
        self.args = args
        self.controlnet = None
        self.controlnet_name = None  # Name of the currently loaded ControlNet, if any
        # Check arguments
        t2i_root_path = Path(models_root_path) / "t2i"
        self.root = t2i_root_path
        logger.info(f"Got text-to-image model root path: {t2i_root_path}")

        # Set device and disable gradient
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.set_grad_enabled(False)
        # Disable BertModel logging checkpoint info
        tf_logger.setLevel("ERROR")

        # ========================================================================
        logger.info(f"Loading CLIP Text Encoder...")
        text_encoder_path = self.root / "clip_text_encoder"
        self.clip_text_encoder = BertModel.from_pretrained(
            str(text_encoder_path), False, revision=None
        ).to(self.device)
        logger.info(f"Loading CLIP Text Encoder finished")

        # ========================================================================
        logger.info(f"Loading CLIP Tokenizer...")
        tokenizer_path = self.root / "tokenizer"
        self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
        logger.info(f"Loading CLIP Tokenizer finished")

        # ========================================================================
        logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
        t5_text_encoder_path = self.root / "mt5"
        embedder_t5 = MT5Embedder(
            t5_text_encoder_path, torch_dtype=torch.float16, max_length=256
        )
        self.embedder_t5 = embedder_t5
        self.embedder_t5.model.to(self.device)  # Only move encoder to device
        logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")

        # ========================================================================
        logger.info(f"Loading VAE...")
        vae_path = self.root / "sdxl-vae-fp16-fix"
        self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
        logger.info(f"Loading VAE finished")

        # ========================================================================
        # Create model structure and load the checkpoint
        logger.info(f"Building HunYuan-DiT model...")
        model_config = HUNYUAN_DIT_CONFIG[self.args.model]
        self.patch_size = model_config["patch_size"]
        self.head_size = model_config["hidden_size"] // model_config["num_heads"]
        self.resolutions, self.freqs_cis_img = (
            self.standard_shapes()
        )  # Used for TensorRT models
        self.image_size = _to_tuple(self.args.image_size)
        latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)

        self.infer_mode = self.args.infer_mode
        if self.infer_mode in ["fa", "torch"]:
            # Build model structure
            self.model = (
                HunYuanDiT(
                    self.args,
                    input_size=latent_size,
                    **model_config,
                    log_fn=logger.info,
                )
                .half()
                .to(self.device)
            )  # Force to use fp16

            # Load model checkpoint
            self.load_torch_weights()
            lora_ckpt = args.lora_ckpt
            if lora_ckpt:
                logger.info(f"Loading Lora checkpoint {lora_ckpt}...")
                self.model.load_adapter(lora_ckpt)
                self.model.merge_and_unload()

            # Gradients are disabled globally, but eval() is still needed to
            # switch off dropout and other train-mode behavior.
            self.model.eval()
            logger.info("Loading torch model finished")
        elif self.infer_mode == "trt":
            from .modules.trt.hcf_model import TRTModel

            trt_dir = self.root / "model_trt"
            engine_dir = trt_dir / "engine"
            plugin_path = trt_dir / "fmha_plugins/10.1_plugin_cuda11/fMHAPlugin.so"
            model_name = "model_onnx"

            logger.info(f"Loading TensorRT model {engine_dir}/{model_name}...")
            self.model = TRTModel(
                model_name=model_name,
                engine_dir=str(engine_dir),
                image_height=TRT_MAX_HEIGHT,
                image_width=TRT_MAX_WIDTH,
                text_maxlen=args.text_len,
                embedding_dim=args.text_states_dim,
                plugin_path=str(plugin_path),
                max_batch_size=TRT_MAX_BATCH_SIZE,
            )
            logger.info(f"Loading TensorRT model finished")
        else:
            raise ValueError(f"Unknown infer_mode: {self.infer_mode}")

        # ========================================================================
        # Build inference pipeline. We use a customized StableDiffusionPipeline.
        logger.info(f"Loading inference pipeline...")
        self.pipeline, self.sampler = self.load_sampler()
        logger.info(f"Loading pipeline finished")

        # ========================================================================
        self.default_negative_prompt = NEGATIVE_PROMPT
        logger.info("==================================================")
        logger.info(f"                Model is ready.                  ")
        logger.info("==================================================")

    def load_torch_weights(self):
        load_key = self.args.load_key
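        # `bare_model` records the checkpoint layout: True for raw weights
        # (pytorch_model_*.pt), False for a dict keyed by `load_key` (e.g.
        # deepspeed's *_model_states.pt), and "unknown" until a single weight
        # file has been inspected.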
        if self.args.dit_weight is not None:
            dit_weight = Path(self.args.dit_weight)
            if dit_weight.is_dir():
                files = list(dit_weight.glob("*.pt"))
                if len(files) == 0:
                    raise ValueError(f"No model weights found in {dit_weight}")
                if any(f.name.startswith("pytorch_model_") for f in files):
                    model_path = dit_weight / f"pytorch_model_{load_key}.pt"
                    bare_model = True
                elif any(str(f).endswith("_model_states.pt") for f in files):
                    files = [f for f in files if str(f).endswith("_model_states.pt")]
                    model_path = files[0]
                    if len(files) > 1:
                        logger.warning(
                            f"Multiple model weights found in {dit_weight}, using {model_path}"
                        )
                    bare_model = False
                else:
                    raise ValueError(
                        f"Invalid model path: {dit_weight} with unrecognized weight format: "
                        f"{list(map(str, files))}. When given a directory as --dit-weight, only "
                        f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
                        f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
                        f"specific weight file, please provide the full path to the file."
                    )
            elif dit_weight.is_file():
                model_path = dit_weight
                bare_model = "unknown"
            else:
                raise ValueError(f"Invalid model path: {dit_weight}")
        else:
            model_dir = self.root / "model"
            model_path = model_dir / f"pytorch_model_{load_key}.pt"
            bare_model = True

        if not model_path.exists():
            raise ValueError(f"model_path does not exist: {model_path}")
        logger.info(f"Loading torch model {model_path}...")
        if model_path.suffix == ".safetensors":
            raise NotImplementedError("Loading safetensors is not supported yet.")
        else:
            # Assume it's a single weight file in the *.pt format.
            state_dict = torch.load(
                model_path, map_location=lambda storage, loc: storage
            )

        if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
            bare_model = False
        if bare_model is False:
            if load_key in state_dict:
                state_dict = state_dict[load_key]
            else:
                raise KeyError(
                    f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
                    f"are: {list(state_dict.keys())}."
                )

        if "style_embedder.weight" in state_dict and not hasattr(
            self.model, "style_embedder"
        ):
            raise ValueError(
                "You might be attempting to load the weights of HunYuanDiT version <= 1.1. You need "
                "to set `--use-style-cond --size-cond 1024 1024 --beta-end 0.03` to adapt to these weights. "
                "Alternatively, you can use weights of version >= 1.2, which no longer depend on "
                "these two parameters."
            )
        if "style_embedder.weight" not in state_dict and hasattr(
            self.model, "style_embedder"
        ):
            raise ValueError(
                "You might be attempting to load the weights of HunYuanDiT version >= 1.2. You need "
                "to remove `--use-style-cond` and `--size-cond 1024 1024` to adapt to these weights."
            )
        # Don't set strict=False. Always explicitly check the state_dict.
        self.model.load_state_dict(state_dict, strict=True)
        # Keep the pristine state dict so temporarily merged LoRA weights can
        # be restored after each generation (see `predict`).
        self.base_state_dict = state_dict
        self.model_name = model_path.name

    def load_controlnet(self, controlnet=None, sampler=None):
        if controlnet is not None:
            latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)
            model_config = HUNYUAN_DIT_CONFIG[self.args.model]
            logger.info(f"Loading controlnet ")
            if self.infer_mode in ["fa", "torch"]:
                # Build model structure
                controlnet_dir = self.root / "controlnet"
                controlnet_path = controlnet_dir / f"{controlnet}.pt"

                if not controlnet_path.exists():
                    raise ValueError(f"controlnet_path not exists: {controlnet_path}")
                self.controlnet = (
                    HunYuanControlNet(
                        self.args,
                        input_size=latent_size,
                        **model_config,
                        log_fn=logger.info,
                    )
                    .half()
                    .to(self.device)
                )

                controlnet_state_dict = torch.load(controlnet_path)
                self.controlnet.load_state_dict(controlnet_state_dict, strict=False)
                self.controlnet.eval()
                self.pipeline, self.sampler = self.load_sampler(sampler=sampler)
                logger.info(f"Loading controlnet finished")

    def load_sampler(self, sampler=None):
        pipeline, sampler = get_pipeline(
            self.args,
            self.vae,
            self.clip_text_encoder,
            self.tokenizer,
            self.model,
            device=self.device,
            rank=0,
            embedder_t5=self.embedder_t5,
            infer_mode=self.infer_mode,
            sampler=sampler,
            controlnet=self.controlnet,
        )
        return pipeline, sampler

    def calc_rope(self, height, width):
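        """Build the 2D rotary position embedding for the latent patch grid.

        Pixel height/width are divided by 8 (the VAE downsampling factor) and
        by the DiT patch size to get the token grid, which is aligned against
        the base 512px grid via get_fill_resize_and_crop.
        """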
        th = height // 8 // self.patch_size
        tw = width // 8 // self.patch_size
        base_size = 512 // 8 // self.patch_size
        start, stop = get_fill_resize_and_crop((th, tw), base_size)
        sub_args = [start, stop, (th, tw)]
        rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
        return rope

    def pixel_perfect_resolution(
        self,
        image,
        target_H: int,
        target_W: int,
        resize_mode,
    ):
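        """Resize `image` to (target_W, target_H).

        "Resize" stretches directly; "Crop_and_Resize" scales to cover the
        target and center-crops the overflow; "Resize_and_Fill" scales to fit
        inside the target and pads the borders.
        """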
        if resize_mode == "Resize":
            return image.resize((target_W, target_H))
        elif resize_mode == "Crop_and_Resize":
            original_width, original_height = image.size
            width_ratio = target_W / original_width
            height_ratio = target_H / original_height

            # Scale by the larger ratio so the image covers the target,
            # then center-crop the overflow.
            if width_ratio > height_ratio:
                new_width = target_W
                new_height = int(original_height * width_ratio)
            else:
                new_width = int(original_width * height_ratio)
                new_height = target_H
            resized_image = image.resize((new_width, new_height))

            # Crop from the center; an integer box guarantees the exact target size.
            left = (new_width - target_W) // 2
            top = (new_height - target_H) // 2
            cropped_image = resized_image.crop(
                (left, top, left + target_W, top + target_H)
            )

            return cropped_image
        elif resize_mode == "Resize_and_Fill":
            original_width, original_height = image.size
            width_ratio = target_W / original_width
            height_ratio = target_H / original_height

            # Scale by the smaller ratio so the whole image fits inside the target.
            if width_ratio < height_ratio:
                new_width = target_W
                new_height = int(original_height * width_ratio)
            else:
                new_width = int(original_width * height_ratio)
                new_height = target_H
            resized_image = image.resize((new_width, new_height))

            # Create a new image with the target size, filled with black.
            background_color = (0, 0, 0)
            new_image = Image.new("RGB", (target_W, target_H), background_color)

            # Paste the resized image onto the center of the new image
            paste_left = (target_W - new_width) // 2
            paste_top = (target_H - new_height) // 2
            new_image.paste(resized_image, (paste_left, paste_top))

            return new_image
        else:
            raise ValueError(f"Unknown resize_mode: {resize_mode}")

    def standard_shapes(self):
        resolutions = ResolutionGroup()
        freqs_cis_img = {}
        for reso in resolutions.data:
            freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
        return resolutions, freqs_cis_img

    def predict(
        self,
        user_prompt,
        height=1024,
        width=1024,
        seed=None,
        enhanced_prompt=None,
        negative_prompt=None,
        infer_steps=100,
        guidance_scale=6,
        batch_size=1,
        src_size_cond=(1024, 1024),
        sampler=None,
        use_style_cond=False,
        controlnet=None,
        control_weight=None,
        image=None,
        lora_ctrls=None,
    ):
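        """Run text-to-image sampling.

        A `seed` of None or 0 is replaced by a random one. Height/width are
        aligned to multiples of 16 (fa/torch modes) or snapped to a standard
        shape (trt mode). Returns {"images": samples, "seed": seed}.
        """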
        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if seed is None or seed == 0:
            seed = random.randint(0, 1_000_000)
        if not isinstance(seed, int):
            raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
        generator = set_seeds(seed, device=self.device)
        # ========================================================================
        # Arguments: target_W, target_H
        # ========================================================================
        if width <= 0 or height <= 0:
            raise ValueError(
                f"`height` and `width` must be positive integers, got height={height}, width={width}"
            )
        logger.info(f"Input (height, width) = ({height}, {width})")
        if self.infer_mode in ["fa", "torch"]:
            # Height and width must be integers aligned to multiples of 16.
            target_H = int((height // 16) * 16)
            target_W = int((width // 16) * 16)
            logger.info(f"Align to 16: (height, width) = ({target_H}, {target_W})")
        elif self.infer_mode == "trt":
            target_W, target_H = get_standard_shape(width, height)
            logger.info(
                f"Align to standard shape: (height, width) = ({target_H}, {target_W})"
            )
        else:
            raise ValueError(f"Unknown infer_mode: {self.infer_mode}")

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(user_prompt, str):
            raise TypeError(
                f"`user_prompt` must be a string, but got {type(user_prompt)}"
            )
        user_prompt = user_prompt.strip()
        prompt = user_prompt
        if lora_ctrls:
            for lora_ctrl in lora_ctrls:
                model_name = lora_ctrl["model"]
                weight = lora_ctrl["weight"]
                lora_path = str(self.root / "lora" / f"{model_name}.safetensors")
                logger.info(f"Loading Lora checkpoint {lora_path} (weight={weight})...")
                lora = self.pipeline.lora_state_dict(lora_path)[0]
                lora = load_lora(lora, self.model.state_dict().copy(), weight)
                self.model.load_state_dict(lora)

        if enhanced_prompt is not None:
            if not isinstance(enhanced_prompt, str):
                raise TypeError(
                    f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}"
                )
            enhanced_prompt = enhanced_prompt.strip()
            prompt = enhanced_prompt

        # negative prompt
        if not negative_prompt:
            negative_prompt = self.default_negative_prompt
        if not isinstance(negative_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
            )

        # ========================================================================
        # Arguments: style. (A fixed argument. Do not change it.)
        # ========================================================================
        if use_style_cond:
            # Only for hydit <= 1.1
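            # A flat tensor of length 2 * batch_size: the style id is repeated
            # for the conditional and unconditional halves of the CFG batch
            # (matching `image_meta_size` below).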
            style = torch.as_tensor([0, 0] * batch_size, device=self.device)
        else:
            style = None

        # ========================================================================
        # Inner arguments: image_meta_size (Please refer to SDXL.)
        # ========================================================================
        if src_size_cond is None:
            size_cond = None
            image_meta_size = None
        else:
            # Only for hydit <= 1.1
            if isinstance(src_size_cond, int):
                src_size_cond = [src_size_cond, src_size_cond]
            if not isinstance(src_size_cond, (list, tuple)):
                raise TypeError(
                    f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}"
                )
            if len(src_size_cond) != 2:
                raise ValueError(
                    f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)} elements"
                )
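            # SDXL-style size conditioning: source size, target size, and a
            # (0, 0) crop offset, repeated 2 * batch_size times for the
            # CFG-doubled batch.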
            size_cond = list(src_size_cond) + [target_W, target_H, 0, 0]
            image_meta_size = torch.as_tensor(
                [size_cond] * 2 * batch_size, device=self.device
            )

        if controlnet != "None" and controlnet != self.controlnet:
            self.load_controlnet(controlnet, sampler)
        elif controlnet == "None" and controlnet != self.controlnet:
            torch.cuda.empty_cache()
            logger.info("更新管道")
            self.pipeline, self.sampler = get_pipeline(
                self.args,
                self.vae,
                self.clip_text_encoder,
                self.tokenizer,
                self.model,
                device=self.device,
                rank=0,
                embedder_t5=self.embedder_t5,
                infer_mode=self.infer_mode,
                sampler=sampler,
                controlnet=None,
            )
        # ========================================================================
        start_time = time.time()
        logger.debug(
            f"""
                       prompt: {user_prompt}
              enhanced prompt: {enhanced_prompt}
                         seed: {seed}
              (height, width): {(target_H, target_W)}
              negative_prompt: {negative_prompt}
                   batch_size: {batch_size}
               guidance_scale: {guidance_scale}
                  infer_steps: {infer_steps}
              image_meta_size: {size_cond}
                   controlnet: {controlnet}
               control_weight: {control_weight}
                 active_loras: {lora_ctrls}
        """
        )
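        # Reuse the precomputed RoPE table when the target resolution is one of
        # the standard shapes; otherwise compute it on the fly.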
        reso = f"{target_H}x{target_W}"
        if reso in self.freqs_cis_img:
            freqs_cis_img = self.freqs_cis_img[reso]
        else:
            freqs_cis_img = self.calc_rope(target_H, target_W)

        if sampler is not None and sampler != self.sampler:
            self.pipeline, self.sampler = self.load_sampler(sampler)

        if controlnet != "None" and image is not None:
            samples = self.pipeline(
                height=target_H,
                width=target_W,
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_images_per_prompt=batch_size,
                guidance_scale=guidance_scale,
                num_inference_steps=infer_steps,
                image_meta_size=image_meta_size,
                style=style,
                return_dict=False,
                generator=generator,
                freqs_cis_img=freqs_cis_img,
                use_fp16=self.args.use_fp16,
                image=image,
                learn_sigma=self.args.learn_sigma,
                control_weight=control_weight,
            )[0]
        else:
            samples = self.pipeline(
                height=target_H,
                width=target_W,
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_images_per_prompt=batch_size,
                guidance_scale=guidance_scale,
                num_inference_steps=infer_steps,
                image_meta_size=image_meta_size,
                style=style,
                return_dict=False,
                generator=generator,
                freqs_cis_img=freqs_cis_img,
                use_fp16=self.args.use_fp16,
                learn_sigma=self.args.learn_sigma,
            )[0]
        torch.cuda.empty_cache()
        gen_time = time.time() - start_time
        logger.debug(f"Success, time: {gen_time}")
        if lora_ctrls:
            # Restore the pristine weights after the temporary LoRA merge.
            self.model.load_state_dict(self.base_state_dict, strict=True)
        return {
            "images": samples,
            "seed": seed,
        }
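

# Usage sketch (hypothetical; `args` is the namespace produced by this repo's
# CLI argument parser, and "ckpts" is a local checkpoint directory):
#
#   gen = End2End(args, models_root_path="ckpts")
#   result = gen.predict("a photo of a cat", height=1024, width=1024, seed=42)
#   result["images"][0].save("result.png")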
